{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a Model on Wideband for Signal Detection\n",
    "\n",
    "This notebook demonstrates how to train a YOLO model on the Wideband dataset for energy detection using spectrograms.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the Wideband Dataset for YOLO\n",
    "\n",
    "We are going to use an infinite NewWideband dataset to generate both our train and validation datasets by writing the images and labels to disk how [YOLO expects them to be](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#option-2-create-a-manual-dataset)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variables\n",
    "from torchsig.transforms.dataset_transforms import Spectrogram\n",
    "from torchsig.transforms.target_transforms import YOLOLabel\n",
    "from torchsig.signals.signal_lists import TorchSigSignalLists\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "root = \"./datasets/wideband_detector_example\"\n",
    "fft_size = 1024\n",
    "num_iq_samples_dataset = fft_size ** 2\n",
    "class_list = TorchSigSignalLists.all_signals\n",
    "num_classes = len(class_list)\n",
    "num_train = 100 # size of train dataset\n",
    "num_val = 10 # size of validation dataset\n",
    "\n",
    "# transform data into a spectrogram image\n",
    "transforms = [Spectrogram(fft_size=fft_size)]\n",
    "# YOLO labels are expected to be (class index, x center, y center, width, height)\n",
    "# all normalized to zero, with (0,0) being upper left corner\n",
    "target_transforms = [YOLOLabel()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.dataset_metadata import WidebandMetadata\n",
    "from torchsig.datasets.datamodules import WidebandDataModule\n",
    "from torchsig.datasets.wideband import NewWideband\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Rectangle\n",
    "\n",
    "# create the NewWideband dataset\n",
    "dataset_metadata = WidebandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = 2,\n",
    "    num_signals_max = 3,\n",
    "    transforms=transforms,\n",
    "    target_transforms=target_transforms,\n",
    ")\n",
    "\n",
    "wideband = NewWideband(\n",
    "    dataset_metadata = dataset_metadata\n",
    ")\n",
    "\n",
    "# show sample from dataset\n",
    "data, label = wideband[0]\n",
    "print(f\"Data shape: {data.shape}\")\n",
    "print(f\"Number of signals: {len(label)}\")\n",
    "nl = \"\\n\"\n",
    "print(f\"Labels:\\n{nl.join(str(l) for l in label)}\")\n",
    "\n",
    "height, width = data.shape\n",
    "fig = plt.figure(figsize=(12,6))\n",
    "fig.tight_layout()\n",
    "\n",
    "ax = fig.add_subplot(1,1,1)\n",
    "pos = ax.imshow(data,aspect='auto',cmap='Wistia',vmin=dataset_metadata.noise_power_db)\n",
    "\n",
    "fig.colorbar(pos, ax=ax)\n",
    "\n",
    "\n",
    "for t in label:\n",
    "    classindex, xcenter, ycenter, normwidth, normheight = t\n",
    "\n",
    "\n",
    "    actualwidth = width * normwidth\n",
    "    actualheight = height * normheight\n",
    "\n",
    "    actualxcenter = xcenter * width\n",
    "    actualycenter = ycenter * height\n",
    "\n",
    "    x_lowerleft = actualxcenter - (actualwidth / 2)\n",
    "    y_lowerleft = actualycenter + (actualheight / 2)\n",
    "\n",
    "    # print(x_lowerleft, y_lowerleft, actualwidth, actualheight)\n",
    "\n",
    "    ax.add_patch(Rectangle(\n",
    "        (x_lowerleft, y_lowerleft), \n",
    "        actualwidth, \n",
    "        -actualheight,\n",
    "        linewidth=1, \n",
    "        edgecolor='blue', \n",
    "        facecolor='none'\n",
    "    ))\n",
    "\n",
    "    textDisplay = str(class_list[classindex])\n",
    "    ax.text(x_lowerleft,y_lowerleft,textDisplay, bbox=dict(facecolor='w', alpha=0.5, linewidth=0))\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import shutil\n",
    "import os\n",
    "\n",
    "# writes the images to disk under images/\n",
    "# writes the labels as a txt file under labels/\n",
    "def prepare_dataset(dataset, train: bool, root: str, start_index: int, stop_index: int):\n",
    "    os.makedirs(root, exist_ok = True)\n",
    "    train_path = \"train\" if train else \"val\"\n",
    "    label_dir = f\"{root}/labels/{train_path}\"\n",
    "    image_dir = f\"{root}/images/{train_path}\"\n",
    "    os.makedirs(label_dir, exist_ok = True)\n",
    "    os.makedirs(image_dir, exist_ok = True)\n",
    "\n",
    "    for i in tqdm(range(start_index, stop_index), desc=f\"Writing YOLO {train_path.title()} Dataset\"):\n",
    "        image, labels = dataset[i]\n",
    "        filename_base = str(i).zfill(10)\n",
    "        label_filename = f\"{label_dir}/{filename_base}.txt\"\n",
    "        image_filename = f\"{image_dir}/{filename_base}.png\"\n",
    "\n",
    "        with open(label_filename, \"w\") as f:\n",
    "            line = \"\"\n",
    "            f.write(\"\\n\".join(f\"{x[0]} {x[1]} {x[2]} {x[3]} {x[4]}\" for x in labels))\n",
    "        cv2.imwrite(image_filename, image, [cv2.IMWRITE_PNG_COMPRESSION, 9])\n",
    "\n",
    "if os.path.exists(root):\n",
    "    shutil.rmtree(root)\n",
    "\n",
    "prepare_dataset(wideband, train=True, root=root, start_index=1, stop_index = num_train)\n",
    "prepare_dataset(wideband, train=False, root=root, start_index=num_train, stop_index = num_train + num_val)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create dataset yaml file for ultralytics\n",
    "import yaml\n",
    "import torch\n",
    "\n",
    "config_name = \"wideband_detector_yolo.yaml\"\n",
    "classes = {v: k for v,k in enumerate(class_list)}\n",
    "\n",
    "yolo_config = dict(\n",
    "    path = \"wideband_detector_example\",\n",
    "    train = \"images/train\",\n",
    "    val = \"images/val\",\n",
    "    nc = num_classes,\n",
    "    names = classes\n",
    ")\n",
    "\n",
    "with open(config_name, 'w+') as file:\n",
    "    yaml.dump(yolo_config, file, default_flow_style=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the YOLO Model\n",
    "\n",
    "Below we download the `yolo8x.pt` model (YOLOv8) if it does not exist already, and then instantiate the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import os\n",
    "\n",
    "url = \"https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8x.pt\"\n",
    "model_filepath = \"yolov8x.pt\"\n",
    "\n",
    "if not os.path.exists(model_filepath):\n",
    "    print(f\"Downloading yolov8x.pt from {url}...\")\n",
    "    response = requests.get(url)\n",
    "    if response.status_code == 200:\n",
    "        # Write the content to the local file\n",
    "        with open(model_filepath, 'wb') as file:\n",
    "            file.write(response.content)\n",
    "        print(f\"File downloaded and saved to: {model_filepath}\")\n",
    "    else:\n",
    "        print(f\"Failed to download file. Status code: {response.status_code}\")\n",
    "else:\n",
    "    print(f\"{model_filepath} exists\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# instantiate the model\n",
    "from ultralytics import YOLO\n",
    "\n",
    "model = YOLO(model_filepath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train\n",
    "\n",
    "Train the YOLO model. See [Ultralytics Train](https://docs.ultralytics.com/modes/train/#train-settings) for training hyperparameter options."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = model.train(\n",
    "    data=config_name, \n",
    "    epochs=20, \n",
    "    batch=1,\n",
    "    imgsz=fft_size,\n",
    "    device=0 if torch.cuda.is_available() else \"cpu\",\n",
    "    workers=1,\n",
    "    project=\"yolo\",\n",
    "    name=\"wideband_detector_example\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# View the training progress and performance\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "import os\n",
    "\n",
    "results_img = cv2.imread(os.path.join(results.save_dir, \"results.png\"))\n",
    "plt.figure(figsize = (10,20))\n",
    "plt.imshow(results_img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "We can now use the model to detect signals on spectrogram images. Try running more epochs or a larger dataset for improved performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label = cv2.imread(os.path.join(results.save_dir, \"val_batch1_labels.jpg\"))\n",
    "pred = cv2.imread(os.path.join(results.save_dir, \"val_batch1_pred.jpg\"))\n",
    "\n",
    "f, ax = plt.subplots(1, 2, figsize=(15, 9))\n",
    "ax[0].imshow(label)\n",
    "ax[0].set_title(\"Label\")\n",
    "ax[1].imshow(pred)\n",
    "ax[1].set_title(\"Prediction\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
